{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "QgumzOZ5wgec",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Evaluate Data Drift for train/test datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "rByuPhg7wgei",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import joblib\n",
    "import mlflow\n",
    "import mlflow.sklearn\n",
    "from mlflow.tracking import MlflowClient\n",
    "import pandas as pd\n",
    "\n",
    "from pathlib import Path\n",
    "from sklearn import ensemble\n",
    "from typing import Dict\n",
    "\n",
    "from evidently.pipeline.column_mapping import ColumnMapping\n",
    "\n",
    "from config import MLFLOW_TRACKING_URI, DATA_DIR, FILENAME, REPORTS_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "HiiUl3p8wgej",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "# Silence all warnings for the rest of the session. Both calls install the\n",
    "# same catch-all 'ignore' filter, so their relative order does not matter.\n",
    "warnings.simplefilter('ignore')\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "zw5Tap_Xwgej",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Load Data"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "VqGH1Mr6wgej",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "More information about the dataset can be found in the UCI Machine Learning Repository: https://archive.ics.uci.edu/ml/datasets/bike+sharing+dataset\n",
    "\n",
    "Acknowledgement: Fanaee-T, Hadi, and Gama, Joao, 'Event labeling combining ensemble detectors and background knowledge', Progress in Artificial Intelligence (2013): pp. 1-15, Springer Berlin Heidelberg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "36Gk-YMhwgek",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>instant</th>\n",
       "      <th>dteday</th>\n",
       "      <th>season</th>\n",
       "      <th>yr</th>\n",
       "      <th>mnth</th>\n",
       "      <th>hr</th>\n",
       "      <th>holiday</th>\n",
       "      <th>weekday</th>\n",
       "      <th>workingday</th>\n",
       "      <th>weathersit</th>\n",
       "      <th>temp</th>\n",
       "      <th>atemp</th>\n",
       "      <th>hum</th>\n",
       "      <th>windspeed</th>\n",
       "      <th>casual</th>\n",
       "      <th>registered</th>\n",
       "      <th>cnt</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>2011-01-01</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.2879</td>\n",
       "      <td>0.81</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3</td>\n",
       "      <td>13</td>\n",
       "      <td>16</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>2011-01-01</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.22</td>\n",
       "      <td>0.2727</td>\n",
       "      <td>0.80</td>\n",
       "      <td>0.0</td>\n",
       "      <td>8</td>\n",
       "      <td>32</td>\n",
       "      <td>40</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>2011-01-01</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.22</td>\n",
       "      <td>0.2727</td>\n",
       "      <td>0.80</td>\n",
       "      <td>0.0</td>\n",
       "      <td>5</td>\n",
       "      <td>27</td>\n",
       "      <td>32</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>2011-01-01</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.2879</td>\n",
       "      <td>0.75</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3</td>\n",
       "      <td>10</td>\n",
       "      <td>13</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>2011-01-01</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.2879</td>\n",
       "      <td>0.75</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   instant      dteday  season  yr  mnth  hr  holiday  weekday  workingday  \\\n",
       "0        1  2011-01-01       1   0     1   0        0        6           0   \n",
       "1        2  2011-01-01       1   0     1   1        0        6           0   \n",
       "2        3  2011-01-01       1   0     1   2        0        6           0   \n",
       "3        4  2011-01-01       1   0     1   3        0        6           0   \n",
       "4        5  2011-01-01       1   0     1   4        0        6           0   \n",
       "\n",
       "   weathersit  temp   atemp   hum  windspeed  casual  registered  cnt  \n",
       "0           1  0.24  0.2879  0.81        0.0       3          13   16  \n",
       "1           1  0.22  0.2727  0.80        0.0       8          32   40  \n",
       "2           1  0.22  0.2727  0.80        0.0       5          27   32  \n",
       "3           1  0.24  0.2879  0.75        0.0       3          10   13  \n",
       "4           1  0.24  0.2879  0.75        0.0       0           1    1  "
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fetch the original dataset first with: python src/pipelines/load_data.py\n",
    "raw_data_path = f\"../{DATA_DIR}/{FILENAME}\"\n",
    "raw_data = pd.read_csv(raw_data_path)\n",
    "\n",
    "# Preview the first rows\n",
    "raw_data.head()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "dhZOCJZ1wgel",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Define column mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "_i8edS6Ewgem",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Column roles; kept as module-level names because later cells reuse them\n",
    "target = 'cnt'\n",
    "prediction = 'prediction'\n",
    "datetime = 'dteday'\n",
    "numerical_features = ['temp', 'atemp', 'hum', 'windspeed', 'mnth', 'hr', 'weekday']\n",
    "categorical_features = ['season', 'holiday', 'workingday']\n",
    "\n",
    "# Hand the same roles to Evidently in a single constructor call\n",
    "column_mapping = ColumnMapping(\n",
    "    target=target,\n",
    "    prediction=prediction,\n",
    "    datetime=datetime,\n",
    "    numerical_features=numerical_features,\n",
    "    categorical_features=categorical_features,\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "id": "q5lW24Xzwgex",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Define the comparison windows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "L6PKtAGEwgey",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Reference window used for the very first training run\n",
    "start_date_0 = '2011-01-02 00:00:00'\n",
    "end_date_0 = '2011-01-30 23:00:00'\n",
    "\n",
    "# Consecutive one-week (start, end) windows used as the 'current' data.\n",
    "# Every window starts at 00:00 on the day after the previous one ends, so\n",
    "# the batches are contiguous and none of them loses most of its first day.\n",
    "experiment_batches = [\n",
    "    ('2011-01-31 00:00:00', '2011-02-06 23:00:00'),\n",
    "    ('2011-02-07 00:00:00', '2011-02-13 23:00:00'),\n",
    "    ('2011-02-14 00:00:00', '2011-02-20 23:00:00'),\n",
    "    ('2011-02-21 00:00:00', '2011-02-27 23:00:00'),\n",
    "    ('2011-02-28 00:00:00', '2011-03-06 23:00:00'),\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define the Reference data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(617, 16)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>instant</th>\n",
       "      <th>season</th>\n",
       "      <th>yr</th>\n",
       "      <th>mnth</th>\n",
       "      <th>hr</th>\n",
       "      <th>holiday</th>\n",
       "      <th>weekday</th>\n",
       "      <th>workingday</th>\n",
       "      <th>weathersit</th>\n",
       "      <th>temp</th>\n",
       "      <th>atemp</th>\n",
       "      <th>hum</th>\n",
       "      <th>windspeed</th>\n",
       "      <th>casual</th>\n",
       "      <th>registered</th>\n",
       "      <th>cnt</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>dteday</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>2011-01-03</th>\n",
       "      <td>48</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.22</td>\n",
       "      <td>0.1970</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.3582</td>\n",
       "      <td>0</td>\n",
       "      <td>5</td>\n",
       "      <td>5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2011-01-03</th>\n",
       "      <td>49</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.20</td>\n",
       "      <td>0.1667</td>\n",
       "      <td>0.44</td>\n",
       "      <td>0.4179</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2011-01-03</th>\n",
       "      <td>50</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.1364</td>\n",
       "      <td>0.47</td>\n",
       "      <td>0.3881</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2011-01-03</th>\n",
       "      <td>51</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.16</td>\n",
       "      <td>0.1364</td>\n",
       "      <td>0.47</td>\n",
       "      <td>0.2836</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2011-01-03</th>\n",
       "      <td>52</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>0.14</td>\n",
       "      <td>0.1061</td>\n",
       "      <td>0.50</td>\n",
       "      <td>0.3881</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>30</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "            instant  season  yr  mnth  hr  holiday  weekday  workingday  \\\n",
       "dteday                                                                    \n",
       "2011-01-03       48       1   0     1   0        0        1           1   \n",
       "2011-01-03       49       1   0     1   1        0        1           1   \n",
       "2011-01-03       50       1   0     1   4        0        1           1   \n",
       "2011-01-03       51       1   0     1   5        0        1           1   \n",
       "2011-01-03       52       1   0     1   6        0        1           1   \n",
       "\n",
       "            weathersit  temp   atemp   hum  windspeed  casual  registered  cnt  \n",
       "dteday                                                                          \n",
       "2011-01-03           1  0.22  0.1970  0.44     0.3582       0           5    5  \n",
       "2011-01-03           1  0.20  0.1667  0.44     0.4179       0           2    2  \n",
       "2011-01-03           1  0.16  0.1364  0.47     0.3881       0           1    1  \n",
       "2011-01-03           1  0.16  0.1364  0.47     0.2836       0           3    3  \n",
       "2011-01-03           1  0.14  0.1061  0.50     0.3881       0          30   30  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index the data by a real timestamp (calendar date + hour of day).\n",
    "# With the raw 'dteday' strings as the index, .loc slicing compares text, and\n",
    "# '2011-01-02' sorts before '2011-01-02 00:00:00', so the first day of every\n",
    "# window was silently dropped.\n",
    "# The guard keeps this cell safe to re-run once 'dteday' is already the index.\n",
    "if 'dteday' in raw_data.columns:\n",
    "    timestamps = pd.to_datetime(raw_data['dteday']) + pd.to_timedelta(raw_data['hr'], unit='h')\n",
    "    raw_data = (\n",
    "        raw_data\n",
    "        .drop(columns='dteday')\n",
    "        .set_index(pd.DatetimeIndex(timestamps, name='dteday'))\n",
    "    )\n",
    "\n",
    "# Define the reference dataset (both window bounds are inclusive)\n",
    "reference = raw_data.loc[start_date_0:end_date_0]\n",
    "\n",
    "print(reference.shape)\n",
    "reference.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.reports import (\n",
    "    build_regression_quality_report,\n",
    "    get_regression_quality_metrics,\n",
    "    build_data_drift_report,\n",
    "    get_data_drift_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023/08/14 13:21:36 INFO mlflow.tracking.fluent: Experiment with name 'Data Drift' does not exist. Creating a new experiment.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Client tracking uri: http://localhost:5000\n"
     ]
    }
   ],
   "source": [
    "# Point MLflow at the configured tracking URI and show what the client uses\n",
    "mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)\n",
    "client = MlflowClient()\n",
    "print(f\"Client tracking uri: {client.tracking_uri}\")\n",
    "\n",
    "# Select the experiment by name\n",
    "mlflow.set_experiment(\"Data Drift\")\n",
    "\n",
    "# Experiment variables: where the final model is written, and the initial\n",
    "# end of the training window\n",
    "model_path = Path('..') / 'models' / 'model.joblib'\n",
    "ref_end_data = end_date_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Experiment id: 501037242071821375\n",
      "Run id: 8aea2cf44c8e4aad8a4e606560dbf8de\n",
      "Run name: beautiful-fox-682\n",
      "Batch: 0\n",
      "Train dates: ('2011-01-02 00:00:00', '2011-01-30 23:00:00')\n",
      "Test (current) dates: ('2011-01-31 00:00:00', '2011-02-06 23:00:00')\n",
      "X_train (reference) dataset shape:  (617, 10) (617,)\n",
      "X_test (current)) dataset shape:  (141, 10) (141,)\n",
      "Run id: 9f35f9bd62b54f5aa853fbd341d257a4\n",
      "Run name: 2011-02-06 23:00:00\n",
      "Batch: 1\n",
      "Train dates: ('2011-01-02 00:00:00', '2011-02-06 23:00:00')\n",
      "Test (current) dates: ('2011-02-07 23:00:00', '2011-02-13 23:00:00')\n",
      "X_train (reference) dataset shape:  (782, 10) (782,)\n",
      "X_test (current)) dataset shape:  (139, 10) (139,)\n",
      "Run id: 26a93a6a2e4843f3b734806d9e94cf5e\n",
      "Run name: 2011-02-13 23:00:00\n",
      "Batch: 2\n",
      "Train dates: ('2011-01-02 00:00:00', '2011-02-13 23:00:00')\n",
      "Test (current) dates: ('2011-02-14 23:00:00', '2011-02-20 23:00:00')\n",
      "X_train (reference) dataset shape:  (945, 10) (945,)\n",
      "X_test (current)) dataset shape:  (141, 10) (141,)\n",
      "Run id: 87b860de9a5547b29273c2e72d157c3a\n",
      "Run name: 2011-02-20 23:00:00\n",
      "Batch: 3\n",
      "Train dates: ('2011-01-02 00:00:00', '2011-02-20 23:00:00')\n",
      "Test (current) dates: ('2011-02-21 00:00:00', '2011-02-27 23:00:00')\n",
      "X_train (reference) dataset shape:  (1110, 10) (1110,)\n",
      "X_test (current)) dataset shape:  (134, 10) (134,)\n",
      "Run id: b5f35dd018c14c33b28ade0719bf7217\n",
      "Run name: 2011-02-27 23:00:00\n",
      "Batch: 4\n",
      "Train dates: ('2011-01-02 00:00:00', '2011-02-27 23:00:00')\n",
      "Test (current) dates: ('2011-02-28 00:00:00', '2011-03-06 23:00:00')\n",
      "X_train (reference) dataset shape:  (1268, 10) (1268,)\n",
      "X_test (current)) dataset shape:  (143, 10) (143,)\n",
      "Run id: acf32f79382146058aa0f915cca02239\n",
      "Run name: 2011-03-06 23:00:00\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Restart the expanding training window and define the model inputs\n",
    "ref_end_data = end_date_0\n",
    "FEATURE_COLUMNS = numerical_features + categorical_features\n",
    "\n",
    "# Create the output folders up front so saving the reports and the model\n",
    "# cannot fail just because a directory is missing\n",
    "reports_dir = Path(f\"../{REPORTS_DIR}\")\n",
    "reports_dir.mkdir(parents=True, exist_ok=True)\n",
    "model_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Start a new Run (Parent Run)\n",
    "with mlflow.start_run() as run:\n",
    "\n",
    "    # Show newly created run metadata info\n",
    "    print(\"Experiment id: {}\".format(run.info.experiment_id))\n",
    "    print(\"Run id: {}\".format(run.info.run_id))\n",
    "    print(\"Run name: {}\".format(run.info.run_name))\n",
    "    run_id = run.info.run_id\n",
    "\n",
    "    # Per-batch metrics keyed by the batch end date; averaged after the loop\n",
    "    metrics_model = {}  # Model Quality metrics\n",
    "    metrics_data = {}   # Data Drift metrics\n",
    "\n",
    "    # Train and evaluate one model per batch on an expanding training window\n",
    "    for k, test_dates in enumerate(experiment_batches):\n",
    "\n",
    "        print(f\"Batch: {k}\")\n",
    "\n",
    "        train_dates = start_date_0, ref_end_data\n",
    "        ref_end_data = test_dates[1]  # Update reference end date for the next train batch\n",
    "        print(f\"Train dates: {train_dates}\")\n",
    "        print(f\"Test (current) dates: {test_dates}\")\n",
    "\n",
    "        # .copy() makes each window an independent frame, so adding the\n",
    "        # 'prediction' column below is not a chained assignment on a slice\n",
    "        # of raw_data\n",
    "        train_data = raw_data.loc[train_dates[0]:train_dates[1]].copy()\n",
    "        X_train = train_data.loc[:, FEATURE_COLUMNS]\n",
    "        y_train = train_data.loc[:, target]\n",
    "        print(\"X_train (reference) dataset shape: \", X_train.shape, y_train.shape)\n",
    "\n",
    "        test_data = raw_data.loc[test_dates[0]:test_dates[1]].copy()\n",
    "        X_test = test_data.loc[:, FEATURE_COLUMNS]\n",
    "        y_test = test_data[target]\n",
    "        print(\"X_test (current)) dataset shape: \",  X_test.shape, y_test.shape)\n",
    "\n",
    "        # Train model\n",
    "        regressor = ensemble.RandomForestRegressor(random_state=0, n_estimators=50)\n",
    "        regressor.fit(X_train, y_train)\n",
    "\n",
    "        # Make predictions for both windows and store them next to the target\n",
    "        train_data['prediction'] = regressor.predict(X_train)\n",
    "        test_data['prediction'] = regressor.predict(X_test)\n",
    "\n",
    "        # Calculate Model Quality metrics\n",
    "        regression_quality_report = build_regression_quality_report(\n",
    "            reference_data=train_data,\n",
    "            current_data=test_data,\n",
    "            column_mapping=column_mapping\n",
    "        )\n",
    "\n",
    "        # Extract metrics from the Model Quality report\n",
    "        train_metrics = get_regression_quality_metrics(regression_quality_report)\n",
    "        metrics_model[test_dates[1]] = train_metrics\n",
    "\n",
    "        # Calculate Data Drift metrics on the feature columns only\n",
    "        data_drift_report = build_data_drift_report(\n",
    "            reference_data=X_train.reset_index(drop=True),\n",
    "            current_data=X_test.reset_index(drop=True),\n",
    "            column_mapping=column_mapping,\n",
    "            drift_share=0.4\n",
    "        )\n",
    "        data_drift_metrics: Dict = get_data_drift_metrics(data_drift_report)\n",
    "        metrics_data[test_dates[1]] = data_drift_metrics\n",
    "\n",
    "        # Run a Child Run for each batch, named after the batch end date\n",
    "        with mlflow.start_run(run_name=test_dates[1], nested=True) as nested_run:\n",
    "\n",
    "            # Show newly created run metadata info\n",
    "            print(\"Run id: {}\".format(nested_run.info.run_id))\n",
    "            print(\"Run name: {}\".format(nested_run.info.run_name))\n",
    "\n",
    "            # Log parameters\n",
    "            mlflow.log_param(\"begin\", test_dates[0])\n",
    "            mlflow.log_param(\"end\", test_dates[1])\n",
    "\n",
    "            # Log metrics\n",
    "            mlflow.log_metrics(train_metrics)\n",
    "            mlflow.log_metrics(data_drift_metrics)\n",
    "\n",
    "            # Save the Model Quality report & log it. The local file is\n",
    "            # overwritten on every batch; each child run keeps its own copy.\n",
    "            model_quality_report_path = f\"../{REPORTS_DIR}/model_quality_report.html\"\n",
    "            regression_quality_report.save_html(model_quality_report_path)\n",
    "            mlflow.log_artifact(model_quality_report_path)\n",
    "\n",
    "            # Log Data Drift report ONLY if drift is detected.\n",
    "            # Truthiness is used instead of `is True` so the check also works\n",
    "            # when the flag is a numpy.bool_ or a 0/1 number.\n",
    "            if data_drift_metrics['dataset_drift']:\n",
    "                report_date = test_dates[1].split(' ')[0]\n",
    "                data_drift_report_path = f\"../{REPORTS_DIR}/data_drift_report_{report_date}.html\"\n",
    "                data_drift_report.save_html(data_drift_report_path)\n",
    "                mlflow.log_artifact(data_drift_report_path)\n",
    "\n",
    "    # Save final model (the regressor fitted in the last batch)\n",
    "    joblib.dump(regressor, model_path)\n",
    "\n",
    "    # Log the last batch model as the parent Run model\n",
    "    mlflow.log_artifact(model_path)\n",
    "\n",
    "    # Log metrics averaged over all batches on the parent Run\n",
    "    avg_model_metrics = pd.DataFrame.from_dict(metrics_model).T.mean().round(3).to_dict()\n",
    "    mlflow.log_metrics(avg_model_metrics)\n",
    "\n",
    "    avg_data_metrics = pd.DataFrame.from_dict(metrics_data).T.mean().round(3).to_dict()\n",
    "    mlflow.log_metrics(avg_data_metrics)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
